Add streaming inference API #46
📝 Walkthrough
Added streaming generation interfaces and implementations to TimeSeriesLLM, OpenTSLMSP, and OpenTSLMFlamingo; refactored Flamingo to build and use token embeddings with explicit vision encoding and media-conditioning; adjusted tokenizer embedding resize behavior; added unit tests and streaming utilities.
Sequence Diagram(s)
sequenceDiagram
actor Client
participant Model as LLM Wrapper
participant Encoder as Vision/LLM Encoder
participant Streamer as TextIteratorStreamer
participant Thread as WorkerThread
Client->>Model: stream_generate(batch, max_new_tokens, ...)
Model->>Model: _validate_streaming_batch(batch)
Model->>Model: build inputs_embeds & attention_mask
alt Flamingo multimodal path
Model->>Encoder: _encode_vision_x(images)
Model->>Model: _condition_media_locations(input_ids)
end
Model->>Streamer: instantiate TextIteratorStreamer(skip_prompt=True)
Model->>Thread: start daemon running generate_fn
Thread->>Encoder: call model.generate(inputs_embeds=..., streamer=Streamer)
Encoder-->>Streamer: push token chunks
loop stream tokens
Streamer-->>Model: yield text chunk
Model-->>Client: yield text chunk
end
Model->>Thread: join thread (finally)
Model->>Model: clear conditioned layers (if any)
Model-->>Client: finish
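The flow in the diagram (a worker thread pushes chunks, the caller iterates over them, the thread is joined in a `finally`) can be sketched with the standard library alone. This is a minimal illustration, not the PR's implementation: `stream_from_worker`, `fake_generate`, and the `queue.Queue` standing in for `TextIteratorStreamer` are all hypothetical names chosen for the example.

```python
import queue
import threading
from typing import Callable, Iterator, List

_DONE = object()  # sentinel marking the end of the stream


def stream_from_worker(
    produce: Callable[[Callable[[str], None]], None],
) -> Iterator[str]:
    """Run `produce(push)` in a daemon thread and yield each pushed chunk."""
    chunks: "queue.Queue[object]" = queue.Queue()  # stand-in for the streamer
    errors: List[BaseException] = []

    def runner() -> None:
        try:
            produce(chunks.put)
        except BaseException as exc:  # recorded here, re-raised in the consumer
            errors.append(exc)
        finally:
            chunks.put(_DONE)  # always unblock the consumer

    thread = threading.Thread(target=runner, daemon=True)
    thread.start()
    try:
        while True:
            item = chunks.get()
            if item is _DONE:
                break
            yield item  # hand the chunk to the caller as soon as it arrives
    finally:
        thread.join()  # mirrors "join thread (finally)" in the diagram
    if errors:
        raise errors[0]


def fake_generate(push: Callable[[str], None]) -> None:
    """Hypothetical generation function that emits four text chunks."""
    for piece in ["The ", "signal ", "is ", "stable."]:
        push(piece)


text = "".join(stream_from_worker(fake_generate))
print(text)
```

Because the error is recorded before the sentinel is queued, a failure in the worker surfaces in the consuming loop after the chunks produced so far have been delivered.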
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 6
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/opentslm/model/llm/OpenTSLMFlamingo.py`:
- Around line 296-308: The sync generation path may leak conditioned media state
because self.model.lang_encoder.clear_conditioned_layers() is only called after
a successful lang_encoder.generate; wrap the conditioned cleanup in a finally
block so clear_conditioned_layers() always runs even if generate() raises.
Specifically, in _forward_with_embeddings()/the sync branch where you call
self.model._encode_vision_x(...), self._condition_media_locations(input_ids) and
then gen_ids = self.model.lang_encoder.generate(...), ensure you call
self.model.lang_encoder.clear_conditioned_layers() inside a finally (or
try/finally) surrounding the generate() call so conditioned layers are cleared
on success or exception. Use the same cleanup pattern as stream_generate() to
avoid leaking state between requests.
- Around line 182-191: The loop in OpenTSLMFlamingo._condition_media_locations
currently accesses decoder layers via
self.model.lang_encoder.get_decoder().layers which hard-codes a Llama-style
accessor; change it to use the model-specific accessor
self.model.lang_encoder._get_decoder_layers() so the dynamic
decoder_layers_attr_name set by lang_encoder.set_decoder_layers_attr_name() is
respected (see the pattern used in
TimeSeriesFlamingoWithTrainableEncoder._get_decoder_layers()); update the loop
to iterate over _get_decoder_layers() and leave the rest of the method
(media_locations/attend_previous and per-layer calls to
condition_media_locations and condition_attend_previous) unchanged.
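The cleanup pattern requested in the first comment above can be sketched as follows. `FakeLangEncoder` and `generate_with_cleanup` are hypothetical stand-ins written for this illustration; only the method name `clear_conditioned_layers()` is taken from the review text.

```python
from typing import List


class FakeLangEncoder:
    """Hypothetical stand-in that only tracks whether layers are conditioned."""

    def __init__(self) -> None:
        self.conditioned = False

    def condition(self) -> None:
        self.conditioned = True

    def clear_conditioned_layers(self) -> None:
        self.conditioned = False

    def generate(self, fail: bool = False) -> List[int]:
        if fail:
            raise RuntimeError("generation failed")
        return [1, 2, 3]


def generate_with_cleanup(encoder: FakeLangEncoder, fail: bool = False) -> List[int]:
    """Condition, generate, and always clear the conditioned state."""
    try:
        encoder.condition()  # setup happens inside the try block
        return encoder.generate(fail=fail)
    finally:
        # Runs on success and on exception, so no state leaks between requests.
        encoder.clear_conditioned_layers()


encoder = FakeLangEncoder()
print(generate_with_cleanup(encoder), encoder.conditioned)
```

Placing the setup call inside the `try` means the cleanup also runs when conditioning itself raises, not only when generation does.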
In `@src/opentslm/model/llm/OpenTSLMSP.py`:
- Around line 7-8: The code uses TextIteratorStreamer and the mean_resizing
argument on resize_token_embeddings which are not available in transformers
4.25.0; update the project dependency to a transformers version that includes
these features (at least >=4.28.0 for TextIteratorStreamer and a more recent
release that provides mean_resizing—use a pinned range like
"transformers>=4.28.0,<5" or the specific version you have validated), then run
tests; update the requirements/pyproject entry and any lockfile, and re-run CI;
verify imports of TextIteratorStreamer and calls to
AutoModelForCausalLM.resize_token_embeddings(...) in the OpenTSLMSP initializer
and the resize_token_embeddings usage in OpenTSLMFlamingo (lines referenced in
the review) work without errors.
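A small guard can make the dependency requirement explicit at import time. The sketch below is an assumption-laden illustration: `release_tuple` and `supports_streaming_features` are hypothetical helpers, the default minimum of 4.46 is the value the follow-up commit in this PR settles on, and pre-release suffixes such as `.dev0` are ignored rather than ordered properly. In practice the installed version string could be obtained with `importlib.metadata.version("transformers")`.

```python
import re
from typing import Tuple


def release_tuple(version: str) -> Tuple[int, ...]:
    """Return the leading numeric release segment, e.g. '4.46.0.dev0' -> (4, 46, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def supports_streaming_features(version: str, minimum: str = "4.46") -> bool:
    """True if `version` is at least `minimum` (numeric release segment only)."""
    return release_tuple(version) >= release_tuple(minimum)


print(supports_streaming_features("4.25.0"))  # older than the assumed minimum
print(supports_streaming_features("4.46.1"))  # at or above the assumed minimum
```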
In `@src/opentslm/model/llm/TimeSeriesLLM.py`:
- Around line 49-54: The abstract TimeSeriesLLM.stream_prompt signature needs to
accept generation kwargs so callers typed against TimeSeriesLLM won't get
unexpected keyword errors; update the method signature on
TimeSeriesLLM.stream_prompt to include a generate_kwargs parameter (e.g.,
generate_kwargs: Optional[Dict[str, Any]] = None) and adjust the
NotImplementedError message if desired; ensure the symbol name stream_prompt on
class TimeSeriesLLM matches the subclasses OpenTSLMSP and OpenTSLMFlamingo so
their existing parameters like temperature or stream_timeout are supported when
typed against the base class.
- Around line 64-88: The unconditional thread.join() blocks cancel/timeout
paths; change _iterate_streamer so it does not perform a blocking join: remove
the unconditional thread.join() and either skip joining entirely (relying on
daemon=True) or use a non-blocking join (thread.join(0) or
thread.join(timeout=0)) so the caller can return immediately on
timeout/early-consume; keep the error check (if error is not None: raise error)
after the non-blocking join behavior so exceptions from runner are still
propagated without waiting for a stuck generation thread.
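The non-blocking join suggested in the last comment can be illustrated with a worker that hangs. This is a sketch under stated assumptions: `iterate_with_bounded_join` is a hypothetical helper, a `queue.Queue` with `None` as the end marker stands in for the streamer, and the timeouts are arbitrary demo values.

```python
import queue
import threading
import time
from typing import Iterator, List, Optional


def iterate_with_bounded_join(
    chunks: "queue.Queue[Optional[str]]",
    thread: threading.Thread,
    errors: List[BaseException],
    chunk_timeout: float,
    join_timeout: float = 0.0,
) -> Iterator[str]:
    """Yield chunks until None arrives; never block on a stuck worker."""
    try:
        while True:
            item = chunks.get(timeout=chunk_timeout)  # raises queue.Empty on timeout
            if item is None:
                break
            yield item
    finally:
        thread.join(timeout=join_timeout)  # returns immediately if the worker is stuck
    if errors:
        raise errors[0]  # errors recorded by the worker are still propagated


release = threading.Event()
q: "queue.Queue[Optional[str]]" = queue.Queue()


def slow_runner() -> None:
    q.put("partial")
    release.wait()  # simulates a generation call that hangs
    q.put(None)


worker = threading.Thread(target=slow_runner, daemon=True)
worker.start()

received: List[str] = []
started = time.monotonic()
try:
    for chunk in iterate_with_bounded_join(q, worker, [], chunk_timeout=0.5):
        received.append(chunk)
    timed_out = False
except queue.Empty:
    timed_out = True
elapsed = time.monotonic() - started
release.set()  # let the demo worker finish
print(received, timed_out, round(elapsed, 1))
```

With an unconditional `thread.join()` in the `finally` block, the same loop would wait for the hung worker instead of returning control to the caller after the timeout.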
📒 Files selected for processing (4)
- src/opentslm/model/llm/OpenTSLMFlamingo.py
- src/opentslm/model/llm/OpenTSLMSP.py
- src/opentslm/model/llm/TimeSeriesLLM.py
- test/test_stream_inference.py
- Use _get_decoder_layers() instead of hard-coded get_decoder().layers for multi-architecture support in _condition_media_locations
- Wrap clear_conditioned_layers() in finally block in generate() to prevent state leaks on errors
- Add **generate_kwargs to base TimeSeriesLLM.stream_prompt() signature to match subclass interfaces
- Bump transformers minimum version from >=4.25 to >=4.46 for TextIteratorStreamer and mean_resizing support
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/opentslm/model/llm/OpenTSLMFlamingo.py`:
- Around line 111-121: The constructor currently unconditionally calls
_infer_decoder_layers_attr_name and overwrites any explicit
decoder_layers_attr_name; change the logic to honor a caller-provided override
by only calling _infer_decoder_layers_attr_name when decoder_layers_attr_name is
not provided (e.g., is None/empty), and otherwise use the provided
decoder_layers_attr_name when calling lang_encoder.set_decoder_layers_attr_name;
keep the existing resize_token_embeddings call as-is.
- Around line 170-180: The current sequence calls
self.model._encode_vision_x(...) and self._condition_media_locations(...) which
mutate per-layer conditioned state but only clears that state in a finally after
invoking FlamingoLMMixin/lang_encoder.forward; if either setup step raises the
clear is skipped and state leaks. Wrap the setup and forward so that
clear_conditioned_layers() is registered to run before any mutation can occur:
call clear_conditioned_layers() in a finally that covers _encode_vision_x and
_condition_media_locations (i.e., start the try/finally before those calls),
then perform inputs_embeds building and the call to super(...).forward inside
the try, and always call self.model.lang_encoder.clear_conditioned_layers() in
the finally. Apply the same pattern to the other occurrences referenced (around
the blocks at ~297-310 and ~342-355).
- Around line 329-357: The torch.inference_mode() context is currently outside
run_generation and therefore not active in worker threads; move the
torch.inference_mode() call inside the run_generation function so that
_encode_vision_x, _condition_media_locations and model.lang_encoder.generate
execute under inference mode. Specifically, wrap the body of run_generation (the
calls to self.model._encode_vision_x, self._condition_media_locations, and
self.model.lang_encoder.generate, and the finally block that calls
self.model.lang_encoder.clear_conditioned_layers) in a with
torch.inference_mode(): block so that streaming generation invoked via
_iterate_streamer and the TextIteratorStreamer runs without creating autograd
state.
- Around line 301-313: The generation path currently passes inputs_embeds alone
to self.model.lang_encoder.generate which (on transformers >=4.46) returns only
new tokens, so the existing slice answer_only_ids = gen_ids[:,
input_ids.shape[1]:] and TextIteratorStreamer(skip_prompt=True) drop real
generated tokens; fix by either (A) pass input_ids along with inputs_embeds into
self.model.lang_encoder.generate so gen_ids includes the prompt (update the
generate call at the inputs_embeds usage and keep answer_only_ids slicing and
TextIteratorStreamer(skip_prompt=True)), or (B) treat gen_ids as "new tokens
only" by removing the slice (set answer_only_ids = gen_ids) and change
TextIteratorStreamer to skip_prompt=False (update the code paths that create
answer_only_ids and the TextIteratorStreamer instantiation). Ensure changes
touch the generate call (self.model.lang_encoder.generate), answer_only_ids
assignment, and TextIteratorStreamer(skip_prompt=...) usages so behavior is
consistent.
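The slicing problem described in the last comment can be reproduced with plain lists. The sketch does not call `transformers`; `answer_ids` is a hypothetical helper, and whether a real `generate()` call returns the prompt depends on how it is invoked, as the comment itself notes.

```python
from typing import List


def answer_ids(gen_ids: List[int], prompt_len: int, includes_prompt: bool) -> List[int]:
    """Return only the newly generated ids."""
    # Slice off the prompt only when the output actually contains it.
    return gen_ids[prompt_len:] if includes_prompt else gen_ids


prompt = [101, 7, 8, 9]
new_tokens = [42, 43, 44]

# Case A: the output contains prompt + new tokens, so slicing is correct.
with_prompt = prompt + new_tokens
print(answer_ids(with_prompt, len(prompt), includes_prompt=True))

# Case B: the output contains new tokens only, so it must not be sliced.
print(answer_ids(new_tokens, len(prompt), includes_prompt=False))

# The bug: slicing a new-tokens-only output by the prompt length drops real tokens.
dropped = new_tokens[len(prompt):]
print(dropped)
```

The same reasoning applies to the streamer: skipping the prompt is only appropriate when the prompt is actually part of what gets streamed.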
⛔ Files ignored due to path filters (1)
- uv.lock is excluded by !**/*.lock
📒 Files selected for processing (3)
- pyproject.toml
- src/opentslm/model/llm/OpenTSLMFlamingo.py
- src/opentslm/model/llm/TimeSeriesLLM.py
✅ Files skipped from review due to trivial changes (1)
- pyproject.toml
- Move _encode_vision_x and _condition_media_locations inside try blocks so clear_conditioned_layers always runs even if setup steps raise
- Add torch.inference_mode() inside run_generation thread since inference_mode is thread-local and does not propagate to new threads
♻️ Current situation & Problem
The current inference API (`generate()`/`eval_prompt()`) returns results only after the entire sequence has been generated. For interactive applications and real-time monitoring dashboards, users must wait for the full generation to complete before seeing any output. This prevents use cases that benefit from incremental token delivery, such as live clinical decision support or streaming analytics pipelines.
⚙️ Release Notes
- Adds `stream_generate()` and `stream_prompt()` methods to `TimeSeriesLLM`, `OpenTSLMSP`, and `OpenTSLMFlamingo` for token-by-token streaming inference via Python iterators.
- Shared helpers (`_validate_streaming_batch`, `_iterate_streamer`) in the `TimeSeriesLLM` base class handle threading and error propagation.
- Updates `resize_token_embeddings` to pass `mean_resizing=False`, preventing failures when the model is initialized on meta tensors in low-memory environments.
Breaking changes
- `OpenTSLMFlamingo.generate()` and `compute_loss()` now use `inputs_embeds` instead of `lang_x`/`vision_x` kwargs internally. Public API is unchanged.
📚 Documentation
The new `stream_generate()` and `stream_prompt()` methods follow the same conventions as the existing `generate()` and `eval_prompt()` methods. In-line docstrings and type annotations are provided. The base class defines the interface and shared utilities; each subclass implements the architecture-specific generation logic.
✅ Testing
A comprehensive test suite is added in `test/test_stream_inference.py` (346 lines) covering:
- `OpenTSLMSP` and `OpenTSLMFlamingo` `stream_prompt` end-to-end flow (prompt conversion, normalization, eval mode)
- `mean_resizing=False` is passed during init for both architectures
Code of Conduct & Contributing Guidelines
By creating and submitting this pull request, you agree to follow our Code of Conduct and Contributing Guidelines: